"""Generate migrations utilities."""

import ast
import inspect
import os
import pathlib
import posixpath
from pathlib import Path
from types import ModuleType

from typing_extensions import override

# Directory that contains this file.
HERE = Path(__file__).parent
# Five directory levels above HERE. The original note says this should land on
# [root]/src; the sibling constants below treat it as the directory holding the
# "langchain", "community" and "partners" folders.
# NOTE(review): the actual target depends on where this file lives in the
# repository -- confirm if the file is ever moved.
PKGS_ROOT = HERE.parent.parent.parent.parent.parent

# Package directories, resolved relative to PKGS_ROOT.
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"


class ImportExtractor(ast.NodeVisitor):
    """Import extractor."""

    def __init__(self, *, from_package: str | None = None) -> None:
        """Extract all imports from the given code, optionally filtering by package."""
        self.imports: list[tuple[str, str]] = []
        self.package = from_package

    @override
    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if node.module and (
            self.package is None or str(node.module).startswith(self.package)
        ):
            for alias in node.names:
                self.imports.append((node.module, alias.name))
        self.generic_visit(node)


def _get_class_names(code: str) -> list[str]:
    """Extract class names from a code string."""
    # Parse the content of the file into an AST
    tree = ast.parse(code)

    # Initialize a list to hold all class names
    class_names = []

    # Define a node visitor class to collect class names
    class ClassVisitor(ast.NodeVisitor):
        @override
        def visit_ClassDef(self, node: ast.ClassDef) -> None:
            class_names.append(node.name)
            self.generic_visit(node)

    # Create an instance of the visitor and visit the AST
    visitor = ClassVisitor()
    visitor.visit(tree)
    return class_names


def is_subclass(class_obj: type, classes_: list[type]) -> bool:
    """Tell whether `class_obj` derives from at least one of the given classes.

    Anything that is not a class is ignored rather than raising: a non-class
    `class_obj` always yields `False`, and non-class entries in `classes_` are
    skipped.

    Args:
        class_obj: The class to check.
        classes_: A list of classes to check against.

    Returns:
        True if `class_obj` is a subclass of any class in `classes_`, `False` otherwise.
    """
    if not inspect.isclass(class_obj):
        return False
    for candidate in classes_:
        if inspect.isclass(candidate) and issubclass(class_obj, candidate):
            return True
    return False


def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
    """Collect the names of classes in `module` that derive from `classes_`.

    Every class attribute of the module is considered, in the order returned by
    `inspect.getmembers`. The reported name is the class's own `__name__`, not
    the attribute name it is bound to in the module.

    Args:
        module: The module to inspect.
        classes_: A list of classes to check against.

    Returns:
        A list of class names that are subclasses of any class in `classes_`.
    """
    return [
        member.__name__
        for _, member in inspect.getmembers(module, inspect.isclass)
        if is_subclass(member, classes_)
    ]


def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]]:
    """Pair every class defined in `file` with the file's dotted module name.

    Args:
        file: Path of the Python source file to read (decoded as UTF-8).
        pkg: Package root that the module name is computed relative to.

    Returns:
        One `(module, class_name)` tuple per class found in the file.
    """
    source = Path(file).read_text(encoding="utf-8")
    owning_module = _get_current_module(file, pkg)
    return [(owning_module, cls_name) for cls_name in _get_class_names(source)]


def identify_all_imports_in_file(
    file: str,
    *,
    from_package: str | None = None,
) -> list[tuple[str, str]]:
    """Read a source file and list its `from ... import ...` statements.

    The file is decoded as UTF-8 and handed to `find_imports_from_package`.

    Args:
        file: The file to analyze.
        from_package: If provided, only return imports from this package.

    Returns:
        A list of tuples `(module, name)` representing the imports found in the file.
    """
    return find_imports_from_package(
        Path(file).read_text(encoding="utf-8"),
        from_package=from_package,
    )


def identify_pkg_source(pkg_root: str) -> pathlib.Path:
    """Locate the single `langchain_*` source directory inside a package root.

    Only the immediate children of `pkg_root` are examined; regular files are
    ignored even if their name starts with `'langchain_'`.

    Args:
        pkg_root: the root of the package. This contains source + tests, and other
            things like pyproject.toml, lock files etc

    Returns:
        Returns the path to the source code for the package.

    Raises:
        ValueError: If there is not exactly one directory starting with `'langchain_'`
            in the package root.
    """
    candidates = [
        entry
        for entry in Path(pkg_root).iterdir()
        if entry.is_dir() and entry.name.startswith("langchain_")
    ]
    if len(candidates) == 1:
        return candidates[0]
    # Zero matches and multiple matches are both reported with the same message.
    msg = "There should be only one langchain package."
    raise ValueError(msg)


def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
    """Enumerate every class defined in a package's source tree.

    The source directory is resolved with `identify_pkg_source`, and every `*.py`
    file below it is scanned recursively.

    Args:
        pkg_root: the root of the package.

    Returns:
        A list of tuples `(module, class_name)` representing all classes found in the
        package, excluding test files.
    """
    source_dir = identify_pkg_source(pkg_root)
    collected: list[tuple[str, str]] = []

    for py_file in list(source_dir.rglob("*.py")):
        # Skip anything whose path, relative to the package root, starts with
        # "tests".
        if not os.path.relpath(py_file, pkg_root).startswith("tests"):
            collected.extend(_get_all_classnames_from_file(py_file, pkg_root))
    return collected


def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
    """List the names imported by each `__init__.py` of a package.

    Each imported name is paired with the module of the `__init__.py` that imports
    it, not with the module the name was imported from.

    Args:
        pkg_root: the root of the package.

    Returns:
        A list of tuples `(module, name)` representing the imports found in
        `__init__.py` files.
    """
    source_dir = Path(identify_pkg_source(pkg_root))
    init_files = [
        candidate
        for candidate in source_dir.rglob("*.py")
        if candidate.name == "__init__.py"
    ]

    collected: list[tuple[str, str]] = []
    for init_file in init_files:
        imported = identify_all_imports_in_file(str(init_file))
        exporting_module = _get_current_module(init_file, pkg_root)
        # Drop the origin module of each import; only the imported name is kept.
        collected.extend((exporting_module, name) for _, name in imported)
    return collected


def find_imports_from_package(
    code: str,
    *,
    from_package: str | None = None,
) -> list[tuple[str, str]]:
    """Parse source code and list its `from ... import ...` statements.

    The work is delegated to `ImportExtractor`, which walks the parsed AST.

    Args:
        code: The code to analyze.
        from_package: If provided, only return imports from this package.

    Returns:
        A list of tuples `(module, name)` representing the imports found.
    """
    collector = ImportExtractor(from_package=from_package)
    collector.visit(ast.parse(code))
    return collector.imports


def _get_current_module(path: Path, pkg_root: str) -> str:
    """Convert a file path to a dotted module name relative to `pkg_root`.

    For example, `<pkg_root>/langchain_x/a/b.py` becomes `'langchain_x.a.b'`, and
    `<pkg_root>/langchain_x/a/__init__.py` becomes `'langchain_x.a'`.

    Args:
        path: Path of the Python file; must be located under `pkg_root`.
        pkg_root: Root directory that the module name is computed relative to.

    Returns:
        The dotted module name, with a trailing `.__init__` component removed.

    Raises:
        ValueError: If `path` is not located under `pkg_root` (raised by
            `Path.relative_to`).
    """
    relative_path = path.relative_to(pkg_root).with_suffix("")
    # Normalise with posixpath rather than os.path: on Windows, os.path.normpath
    # turns "/" back into "\\", which would undo as_posix() and leave the
    # separators in place instead of converting them to dots.
    norm_path = posixpath.normpath(relative_path.as_posix())
    fully_qualified_module = norm_path.replace("/", ".")
    # Strip __init__ if present: a package's __init__.py maps to the package.
    return fully_qualified_module.removesuffix(".__init__")
